from torchvision import transforms
from torchvision import datasets


# Preprocessing pipelines keyed by dataset split ("train" / "val").
# NOTE(review): the Normalize mean/std values below are the widely used
# ImageNet per-channel statistics rather than CIFAR-10's own — presumably
# chosen to match an ImageNet-pretrained backbone; confirm.
data_transform = dict(
    # Training: random-resized 224x224 crop and a random horizontal flip for
    # augmentation, then conversion to a tensor and per-channel normalization.
    train=transforms.Compose(
        [
            transforms.RandomResizedCrop(224),  # random crop, resized to 224x224
            transforms.RandomHorizontalFlip(),  # random horizontal flip
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    # Validation: deterministic preprocessing. Resize so the shorter side is
    # 256, take a 224x224 center crop, then convert and normalize exactly
    # as in training.
    val=transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
)

    

# CIFAR-10 training split; every sample goes through the "train"
# augmentation pipeline defined in data_transform.
# NOTE(review): root is a relative path, so it resolves against the current
# working directory, not this file's location — confirm that is intended.
train_data_set = datasets.CIFAR10(
    "../../data_set",  # root: directory where the dataset files are stored
    download=True,  # download first if the data is not already present
    transform=data_transform["train"],
    train=True,
)